{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "143ca2ff",
   "metadata": {},
   "source": [
    "# 单样本学习和孪生网络\n",
    "ONE SHOT LEARNING（单样本学习） WITH SIAMESE NETWORKS （孪生网络）\n",
    "\n",
    "假设公司有10个人的照片，传统思路是训练一个十分类的网络。但是，如果公司增加了新员工或者有员工离职了，难道每次都要重新训练网络吗？\n",
    "\n",
    "新的思路：公司有员工A、B，用A的两张照片构造一个正样本，用A、B各一张照片构造一个负样本。输入是两张照片和正负标签，两张照片都经过同一个模型得到2个向量，再由这两个向量计算出一个相似度得分，通常用sigmoid函数把这个得分压缩到0~1之间：0表示完全不相似，1表示完全相似，这样就转换成了一个二分类问题。（本notebook实际采用的是另一种常见做法：直接计算两个向量的欧氏距离，并用对比损失（Contrastive Loss）进行训练。）\n",
    "模型训练完成后，将公司所有人的照片过下模型得到每个人人脸照片的向量表示，新来员工则在员工人脸向量库中增加该人的人脸向量即可。当员工刷脸时，过一遍模型得到人脸向量，再和数据库中向量比对，找到最近的那个作为最匹配的。\n",
    "\n",
    "除了人脸识别的应用场景外，还可以用来挖掘一些比较稀少的低质样本。只要人工标注发现一类低质样本，就将该低质样本过模型得到向量存入向量库，再去找没有标注的数据中哪些跟该样本相似，用于增广该类低质样本，使得最终训练的样本保持均衡。\n",
    "\n",
    "* https://zhuanlan.zhihu.com/p/35040994\n",
    "* https://github.com/adambielski/siamese-triplet\n",
    "* https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4b54fd4a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([[ 0.1675,  0.1254, -0.1997, -0.1224, -0.0815,  0.0223,  0.1118, -0.0283],\n",
      "        [ 0.1791,  0.1299, -0.2088, -0.1007, -0.0982,  0.0162,  0.1128,  0.0007]],\n",
      "       grad_fn=<AddmmBackward>), tensor([[ 0.1699,  0.1224, -0.2072, -0.0944, -0.0685, -0.0172,  0.0924, -0.0622],\n",
      "        [ 0.2047,  0.1322, -0.1831, -0.0924, -0.0988,  0.0184,  0.0962,  0.0227]],\n",
      "       grad_fn=<AddmmBackward>))\n",
      "tensor([[ 0.1664,  0.1250, -0.2121, -0.1156, -0.0835,  0.0174,  0.1190, -0.0162]],\n",
      "       grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "import inspect\n",
    "\n",
    "import numpy\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "# 1. Model definition: one MLP encoder shared by both branches of the Siamese network\n",
    "class SiameseModel(nn.Module):\n",
    "    \"\"\"Siamese network whose two branches share a single MLP encoder.\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of input features per sample.\n",
    "        output_dim: length of the embedding vector produced for each sample.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, output_dim=8):\n",
    "        super(SiameseModel, self).__init__()\n",
    "        layers = [\n",
    "            nn.Linear(input_dim, 25), nn.ReLU(),\n",
    "            nn.Linear(25, 30), nn.ReLU(),\n",
    "            nn.Linear(30, output_dim),\n",
    "        ]\n",
    "        self.model = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, input1, input2):\n",
    "        \"\"\"Embed both inputs with the shared encoder and return (v1, v2).\"\"\"\n",
    "        return self.predict(input1), self.predict(input2)\n",
    "\n",
    "    def predict(self, input0):\n",
    "        \"\"\"Embed one batch; at inference time only a single input is needed.\"\"\"\n",
    "        return self.model(input0)\n",
    "\n",
    "# Seed torch's global RNG so weight init, the random smoke-test inputs and the\n",
    "# DataLoader shuffling below are reproducible on Restart & Run All.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Model initialisation\n",
    "input_dim = 4   # iris has 4 numeric features\n",
    "output_dim = 8  # embedding size\n",
    "model = SiameseModel(input_dim, output_dim)\n",
    "\n",
    "# Quick smoke test: the pair forward returns one embedding per input row\n",
    "batch_size = 2\n",
    "input1 = torch.rand(batch_size, input_dim)\n",
    "input2 = torch.rand(batch_size, input_dim)\n",
    "ret = model(input1, input2)\n",
    "assert all(v.shape == (batch_size, output_dim) for v in ret)\n",
    "print(ret)\n",
    "input0 = torch.rand(1, input_dim)\n",
    "ret0 = model.predict(input0)\n",
    "assert ret0.shape == (1, output_dim)\n",
    "print(ret0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b14a1663",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. Custom loss function\n",
    "class ContrastiveLoss(nn.Module):\n",
    "    \"\"\"\n",
    "    Contrastive loss function.\n",
    "    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf\n",
    "\n",
    "    With D = ||output1 - output2||_2 and target = 1 for a similar pair, 0 otherwise:\n",
    "        loss = 0.5 * (target * D^2 + (1 - target) * max(0, margin - D)^2)\n",
    "    Similar pairs are pulled together; dissimilar pairs are pushed apart until\n",
    "    they are at least `margin` away from each other.\n",
    "\n",
    "    Label convention: some references (e.g. the harveyslash Siamese-networks\n",
    "    notebook) use target = 0 for similar pairs. The pair dataset in this notebook\n",
    "    emits 1.0 for same-class pairs, so a loss written with the opposite convention\n",
    "    pushes same-class pairs apart -- most likely why an earlier variant of this\n",
    "    forward (Euclidean distance, label 0 = similar) did not converge.\n",
    "    \"\"\"\n",
    "    def __init__(self, margin=2.0):\n",
    "        super(ContrastiveLoss, self).__init__()\n",
    "        self.margin = margin\n",
    "        # keeps sqrt differentiable when two embeddings coincide (distance 0)\n",
    "        self.eps = 1e-9\n",
    "\n",
    "    def forward(self, output1, output2, target, size_average=True):\n",
    "        \"\"\"Return the contrastive loss for a batch of embedding pairs.\n",
    "\n",
    "        Args:\n",
    "            output1, output2: embeddings of shape (batch, dim).\n",
    "            target: shape (batch,); 1 for a similar pair, 0 for a dissimilar one.\n",
    "            size_average: if True return the batch mean, otherwise the sum.\n",
    "        \"\"\"\n",
    "        target = target.float()\n",
    "        sq_distances = (output2 - output1).pow(2).sum(1)\n",
    "        distances = (sq_distances + self.eps).sqrt()\n",
    "        losses = 0.5 * (target * sq_distances +\n",
    "                        (1 - target) * F.relu(self.margin - distances).pow(2))\n",
    "        return losses.mean() if size_average else losses.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "337d892d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 [7.6 3.  6.6 2.1] [6.9 3.1 5.4 2.1]\n",
      "Labels batch shape: torch.Size([64])\n",
      "Feature1 batch shape: torch.Size([64, 4])\n",
      "Feature2 batch shape: torch.Size([64, 4])\n"
     ]
    }
   ],
   "source": [
    "# 3. 数据集预处理，相同分类的pair为正样本，不同分类的pair为负样本\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "class IrisDataset(Dataset):\n",
    "    \"\"\"All unordered pairs of iris samples, labelled 1.0 (same class) or 0.0.\n",
    "\n",
    "    Reads ./dataset/iris/{train,test}.csv (no header row: 4 feature columns,\n",
    "    then the class name) and materialises every pair (i, j) with i < j up front,\n",
    "    so a file with n rows yields n * (n - 1) / 2 items of (label, x_i, x_j).\n",
    "    \"\"\"\n",
    "    def __init__(self,data_type=\"train\"):\n",
    "        assert data_type in ('train','test')\n",
    "        self.labels = {'Iris-setosa':0,'Iris-versicolor':1,'Iris-virginica':2}\n",
    "        self.pd_frame = pd.read_csv(\"./dataset/iris/%s.csv\" % (data_type),header=None)\n",
    "        self.dataset = list(self.gen_pairs())\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.dataset[idx]\n",
    "\n",
    "    def gen_pairs(self):\n",
    "        \"\"\"Yield (label, x_i, x_j) for every row pair i < j.\n",
    "\n",
    "        label is always a Python float (1.0 same class, 0.0 otherwise), so\n",
    "        DataLoader's default_collate builds every label batch with the same\n",
    "        dtype; mixing 1.0 with int 0 made it depend on each batch's first item.\n",
    "        \"\"\"\n",
    "        # Convert features/classes to numpy once instead of re-slicing the\n",
    "        # DataFrame with iloc for every one of the O(n^2) pairs.\n",
    "        features = self.pd_frame.iloc[:, 0:4].to_numpy(np.float32)\n",
    "        classes = self.pd_frame.iloc[:, 4].to_numpy()\n",
    "        cnt = len(features)\n",
    "        for i in range(cnt - 1):\n",
    "            for j in range(i + 1, cnt):\n",
    "                label = 1.0 if classes[i] == classes[j] else 0.0\n",
    "                # copies so that each item owns its own feature arrays\n",
    "                yield label, features[i].copy(), features[j].copy()\n",
    "            \n",
    "# Quick check of the pair dataset and of the training DataLoader\n",
    "dataset = IrisDataset()  # defaults to the 'train' split\n",
    "label, x1, x2 = dataset[0]\n",
    "print(label, x1, x2)\n",
    "\n",
    "train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)\n",
    "train_labels, train_features1, train_features2 = next(iter(train_dataloader))\n",
    "print(f\"Labels batch shape: {train_labels.size()}\")\n",
    "print(f\"Feature1 batch shape: {train_features1.size()}\")\n",
    "print(f\"Feature2 batch shape: {train_features2.size()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b740a48e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:0 batch:0 loss:0.940151035785675 \n",
      "Epoch:0 batch:9 loss:0.6913157105445862 \n",
      "Epoch:1 batch:0 loss:0.6381360292434692 \n",
      "Epoch:1 batch:9 loss:0.3607829809188843 \n",
      "Epoch:2 batch:0 loss:0.38604989647865295 \n",
      "Epoch:2 batch:9 loss:0.16247235238552094 \n",
      "Epoch:3 batch:0 loss:0.21776190400123596 \n",
      "Epoch:3 batch:9 loss:0.20620711147785187 \n",
      "Epoch:4 batch:0 loss:0.23811917006969452 \n",
      "Epoch:4 batch:9 loss:0.15145206451416016 \n",
      "Epoch:5 batch:0 loss:0.10913698375225067 \n",
      "Epoch:5 batch:9 loss:0.23343323171138763 \n",
      "Epoch:6 batch:0 loss:0.07452350854873657 \n",
      "Epoch:6 batch:9 loss:0.08197259902954102 \n",
      "Epoch:7 batch:0 loss:0.19631728529930115 \n",
      "Epoch:7 batch:9 loss:0.10484776645898819 \n",
      "Epoch:8 batch:0 loss:0.15678831934928894 \n",
      "Epoch:8 batch:9 loss:0.13326188921928406 \n",
      "Epoch:9 batch:0 loss:0.1668071746826172 \n",
      "Epoch:9 batch:9 loss:0.07453184574842453 \n",
      "Epoch:10 batch:0 loss:0.1294322907924652 \n",
      "Epoch:10 batch:9 loss:0.12433502078056335 \n",
      "Epoch:11 batch:0 loss:0.13758789002895355 \n",
      "Epoch:11 batch:9 loss:0.11600666493177414 \n",
      "Epoch:12 batch:0 loss:0.14612257480621338 \n",
      "Epoch:12 batch:9 loss:0.1279204934835434 \n",
      "Epoch:13 batch:0 loss:0.11773449182510376 \n",
      "Epoch:13 batch:9 loss:0.07979785650968552 \n",
      "Epoch:14 batch:0 loss:0.08451176434755325 \n",
      "Epoch:14 batch:9 loss:0.06988897919654846 \n",
      "Epoch:15 batch:0 loss:0.10551144182682037 \n",
      "Epoch:15 batch:9 loss:0.09366660565137863 \n",
      "Epoch:16 batch:0 loss:0.07631758600473404 \n",
      "Epoch:16 batch:9 loss:0.09548690915107727 \n",
      "Epoch:17 batch:0 loss:0.07807736843824387 \n",
      "Epoch:17 batch:9 loss:0.052217718213796616 \n",
      "Epoch:18 batch:0 loss:0.0680030882358551 \n",
      "Epoch:18 batch:9 loss:0.09254780411720276 \n",
      "Epoch:19 batch:0 loss:0.09026715904474258 \n",
      "Epoch:19 batch:9 loss:0.0996447280049324 \n",
      "Epoch:20 batch:0 loss:0.10346568375825882 \n",
      "Epoch:20 batch:9 loss:0.09196952730417252 \n",
      "Epoch:21 batch:0 loss:0.09491568803787231 \n",
      "Epoch:21 batch:9 loss:0.08410424739122391 \n",
      "Epoch:22 batch:0 loss:0.06409312784671783 \n",
      "Epoch:22 batch:9 loss:0.06682770699262619 \n",
      "Epoch:23 batch:0 loss:0.11070827394723892 \n",
      "Epoch:23 batch:9 loss:0.1544617861509323 \n",
      "Epoch:24 batch:0 loss:0.0766429677605629 \n",
      "Epoch:24 batch:9 loss:0.04685258865356445 \n",
      "Epoch:25 batch:0 loss:0.07393959909677505 \n",
      "Epoch:25 batch:9 loss:0.07401518523693085 \n",
      "Epoch:26 batch:0 loss:0.05982859060168266 \n",
      "Epoch:26 batch:9 loss:0.08706589043140411 \n",
      "Epoch:27 batch:0 loss:0.08052462339401245 \n",
      "Epoch:27 batch:9 loss:0.07389035820960999 \n",
      "Epoch:28 batch:0 loss:0.06032995134592056 \n",
      "Epoch:28 batch:9 loss:0.04406735301017761 \n",
      "Epoch:29 batch:0 loss:0.0998634323477745 \n",
      "Epoch:29 batch:9 loss:0.07467623054981232 \n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAsI0lEQVR4nO3deXzU1b3/8ddnJvu+TUIWIAthCSBbRETBvYJLsVqt1mrrWlvt7fK7vXq73PZ2b62tt1dbVPRWbV3rUlTcq4ALS0AWWRKSAEkIkIUkZM8kOb8/ZhKHkGUSJpnMzOf5eOTB5DvfmfnkO+GdM+ec7/mKMQallFK+z+LtApRSSnmGBrpSSvkJDXSllPITGuhKKeUnNNCVUspPBHnrhZOSkkxmZqa3Xl4ppXzSli1baowxtv7u81qgZ2ZmUlBQ4K2XV0opnyQiBwe6T7tclFLKT2igK6WUn9BAV0opP6GBrpRSfkIDXSml/IQGulJK+QkNdKWU8hM+F+iFRxr59et7aGyze7sUpZQaV3wu0MuPtfDQ2lL2VTV5uxSllBpXfC7Qp6ZEA1B0pNHLlSil1Pjic4GeER9OeLCVoqPaQldKKVc+F+gWi5CbEkXRUW2hK6WUK58LdHB0u2igK6XUiXw00KOoamynvqXD26UopdS44ZOBntszMKr96Eop1csnA31ab6Brt4tSSvXwyUBPjQ0jOjRIA10ppVz4ZKCLCFN0potSSp3AJwMdHN0u+7QPXSmlevlsoOemRFPb3EFNU7u3S1FKqXHBZwNdB0aVUupEPhvoU1OiAF3TRSmlevhsoNuiQ4kND6ZIV11USinAhwNdRJwDo9pCV0op8OFAB8hNiaLwSCPGGG+XopRSXufTgT41JZrjbZ1UNepMF6WU8vlAB8dl6ZRSKtD5eKA7Z7poP7pSSvl2oCdGhZIUFaJnjCqlFD4e6AC5ydEUagtdKaV8P9CnpkRRXNWkM12UUgHP9wN9QjRN7Z1UNrR5uxSllPIq3w/0njVddKaLUirA+X6gJ+siXUopBX4Q6LERwaTEhOrAqFIq4Pl8oIOj20WnLiqlAp1bgS4iy0SkUESKReSefu6PFZFXRGS7iOwSkZs8X+rApqZEs6+qke5unemilApcQwa6iFiBB4HlQB5wnYjk9dntTmC3MWYOcC5wn4iEeLjWAU1NiaLN3k15XctYvaRSSo077rTQFwLFxphSY0wH8Aywos8+BogWEQGigGNAp0crHURu79WLtNtFKRW43An0dKDc5fsK5zZXDwAzgEpgJ/BtY0x33ycSkdtFpEBECqqrq0dY8slyk3VNF6WUcifQpZ9tfTurLwa2AWnAXOABEYk56UHGPGyMyTfG5NtstmGWOrDosGDS48I10JVSAc2dQK8AJrp8n4GjJe7qJuBF41AM7Aeme6ZE90xNidIuF6VUQHMn0DcDuSKS5RzovBZY3WefMuACABFJAaYBpZ4sdChTU6IpqWqis+uknh6llAoIQwa6MaYTuAt4E9gDPGeM2SUid4jIHc7dfg4sFpGdwLvA3caYmtEquj+5KdF0dHVz8JjOdFFKBaYgd3YyxqwB1vTZttLldiXwOc+WNjzTXNZ0ybFFebMUpZTyCr84UxRgSnIUIjp1USkVuPwm0MNDrExKiKCoSme6KKUCk98EOjiuXqTL6CqlApVfBfrUlCj21zTT0akzXZRSgcevAn3ahGg6uw37a5q9XYpSSo05vwr0XL3YhVIqgPlVoGfbIrEI7NNAV0oFIL8K9LBgK5lJkXr1IqVUQPKrQAfHNUb16kVKqUDkf4E+IZoDtc202bu8XYpSSo0p/wv0lCi6DZRUaytdKRVY/DDQHTNdtNtFKRVo/C7QMxMjCbaKDowqpQKO3wV6SJCFrKRInbqolAo4fhfo4Oh20VUXlVKBxm8DvexYCy0dnd4uRSmlxoyfBrrjAhfFVdpKV0oFDj8NdMdMl0JdSlcpFUD8
MtAnJUQAUFHX6uVKlFJq7PhloAdZLcSEBVHf0uHtUpRSasz4ZaADxEeGUNdi93YZSik1Zvw20OMiQqjTFrpSKoD4baDHRwRTry10pVQA8eNA1xa6Uiqw+G2gx2kLXSkVYPw20OMjQmhq76Sjs9vbpSil1Jjw40APBqC+VbtdlFKBwW8DPS4iBEC7XZRSAcNvAz3eGejHmrWFrpQKDH4b6HHOLpc6DXSlVIDw20BPjwsHdD0XpVTg8NtAj48MIT4imNIaXUJXKRUY/DbQAbJtUZRUN3u7DKWUGhP+HehJkZRqoCulAoR/B7otipqmdo636dRFpZT/cyvQRWSZiBSKSLGI3DPAPueKyDYR2SUiaz1b5shk2yIBtJWulAoIQwa6iFiBB4HlQB5wnYjk9dknDvgz8HljzEzgas+XOnw5zkDfrwOjSqkA4E4LfSFQbIwpNcZ0AM8AK/rs82XgRWNMGYAxpsqzZY7MpIRIrBbRFrpSKiC4E+jpQLnL9xXOba6mAvEi8r6IbBGRG/t7IhG5XUQKRKSgurp6ZBUPQ0iQhYnx4RroSqmA4E6gSz/bTJ/vg4AFwKXAxcCPRWTqSQ8y5mFjTL4xJt9msw272JFwTF3ULhellP9zJ9ArgIku32cAlf3s84YxptkYUwOsA+Z4psRTk5UUyYHaZrq7+/4NUkop/+JOoG8GckUkS0RCgGuB1X32+SewRESCRCQCOAPY49lSRybbFkmbvZvKBl0CQCnl34KG2sEY0ykidwFvAlbgMWPMLhG5w3n/SmPMHhF5A9gBdAOrjDGfjmbh7spOigIcUxcz4iO8XI1SSo2eIQMdwBizBljTZ9vKPt/fC9zrudI8I6d3LnoTS6eOTb+9Ukp5g1+fKQpgiw4lKjSI0hqd6aKU8m9+H+giQrZN13RRSvk/vw906FmkS6cuKqX8W2AEui2KyoY2Wjo6vV2KUkqNmgAJdMfA6IGaFi9XopRSoycwAr1n6qIu0qWU8mMBEehZSbqMrlLK/wVEoIeHWEmPC9eBUaWUXwuIQAdHK13noiul/FnABHrPXHRjdJEupZR/CpxAT4qkqb2T6sZ2b5eilFKjInAC3eaY6VKiA6NKKT8VQIHunOmiUxeVUn4qYAI9LTacsGCLTl1USvmtgAl0i0XITNQ1XZRS/itgAh0gxxalUxeVUn4roAI92xZJ+bEW2ju7vF2KUkp5XMAFereB8mO6SJdSyv8EVqAn6dRFpZT/CqhAz7LpIl1KKf8VUIEeExZMUlSoznRRSvmlgAp0cK7pojNdlFJ+KOACPcemc9GVUv4p4AI9OymKuhY7dc0d3i5FKaU8KvACXdd0UUr5qQAMdJ26qJTyTwEX6BPjwwm2ik5dVEr5nYAL9CCrhUkJETowqpTyOwEX6ODodtGpi0opfxOggR7Jwdpmurr1+qJKKf8RkIGekxSFvctQUaeLdCml/EdABrqu6aKU8kcBGejZSY5AL9GBUaWUHwnIQE+IDCE2PFgHRpVSfiUgA11EHIt0aQtdKeVH3Ap0EVkmIoUiUiwi9wyy3+ki0iUiX/RciaMjOylK+9CVUn5lyEAXESvwILAcyAOuE5G8Afb7LfCmp4scDdm2SKoa22lss3u7FKWU8gh3WugLgWJjTKkxpgN4BljRz37fAl4AqjxY36jJcc502a/96EopP+FOoKcD5S7fVzi39RKRdOALwMrBnkhEbheRAhEpqK6uHm6tHtWzSJd2uyil/IU7gS79bOt7iuX9wN3GmK7BnsgY87AxJt8Yk2+z2dwscXRMTozAIujAqFLKbwS5sU8FMNHl+wygss8++cAzIgKQBFwiIp3GmJc9UeRoCA2ykhEfQUlNM+XHWrBFhxIWbPV2WUopNWLutNA3A7kikiUiIcC1wGrXHYwxWcaYTGNMJvAP4JvjOcx7ZNsi2Xe0keX/s54/v1/i7XKUUuqUDBnoxphO4C4cs1f2AM8ZY3aJyB0icsdoFziaspOiKDraRFN7
J5+U1Xm7HKWUOiXudLlgjFkDrOmzrd8BUGPM1069rLHRs6YLwO7K4xhjcHYbKaWUzwnIM0V75CR9Fui1zR1UN7Z7sRqllDo1AR3oPVMXe+yqPO6lSpRS6tQFdKCnxIQSGWJldnosALsPa6ArpXxXQAe6iHDLkmxuXZLFpIQIdmsLXSnlw9waFPVn37toKgBvfHqEXZUNXq5GKaVGLqBb6K7yUmM4UNtCU3unt0tRSqkR0UB3ykuLAWCv9qMrpXyUBrpTT6DrwKhSyldpoDtNiAkjITKEXYc00JVSvkkD3UlEyEuN0Ra6UspnaaC7yEuLofBoI/aubm+XopRSw6aB7iIvNYaOzm5KdI10pZQP0kB3MbNnYFRPMFJK+SANdBdZSZGEBlk00JVSPkkD3UWQ1cL0CdE6MKqU8kka6H3kpcWyy7k2ulJK+RIN9D7y0mJoaLVT2dDm7VKUUmpYNND7yEvVgVGllG/SQO9j+oRoRNCVF5VSPkcDvY/I0CCykiK1ha6U8jka6P3QJQCUUr5IA70feWkxVNS10tBq93YpSinlNg30fsxMc15jVLtdlFI+RAO9H70zXbTbRSnlQzTQ+2GLDsUWHaotdKWUT9FAH0BeaoxOXVRK+RQN9AHMTIuhuKqJ9s4ub5eilFJu0UAfQF5aDJ3dhn1HdW10pZRv0EAfgA6MKqV8jQb6ADITI4kIsY5oYHRDaS1PbyobhaqUUmpgGugDsFiEGakxIwr0Vev384OXdnrkUnYdnd08+F4xdc0dp/xcSin/poE+iJ4lALq7h7c2ekVdC8bAn98rOeUa3t1zlHvfLGTdvupTfi6llH/TQB9EXloMTe2dlNe1DOtxh+pbCbYKL287RFnt8B7b16s7DgOOlrpSSg1GA30QfS8a3Wbvorm9c9DHNLTaaWzr5KtnZmK1CH9ZO/JWenN7J+/uPQpAR5cGulJqcEHeLmA8m5oSjdUi7D58nOWzU7n18QI2lNYyf3I8S3OTWJJrY1Z6LFaL9D6mwtmanz85nvbObp7ZXMa3zp9CWlz4sF//X3uraLM7gtyuLXSl1BDcaqGLyDIRKRSRYhG5p5/7rxeRHc6vj0RkjudLHXthwVZybJ+tjb6/ppmJCRE0t3fy+7eKWPHgh+T/4m3uemorz20up7K+lYq6VgAy4sP5+jnZGAMPrysd0eu/uqOSmDDH31xtoSulhjJkC11ErMCDwEVABbBZRFYbY3a77LYfOMcYUyciy4GHgTNGo+Cxlpcaw4bSYwDUt3TwpdMn8V+X51HT1M6HxTWsK6ph/b7q3r7ungDOiI8gITKEq+Zn8PSmMr55Xg7J0WFuv25jm533Cqv54oIMntpYhr1LL1qtlBqcO10uC4FiY0wpgIg8A6wAegPdGPORy/4bgAxPFulNM9NieXlbJUePt9Hc0UV8RDAASVGhrJibzoq56RhjKDraxPp91azbV4NA737fODeH57eU88i6Un54aZ7br/vOnqN0dHbzhXnpPLWxTAdFlVJDcifQ04Fyl+8rGLz1fQvw+qkUNZ7kOQdGPy6pBSDOGdSuRIRpE6KZNiGaW5dkn3BfZlIkK+am87cNZXzj3CkkRIa49bqvbj9MamwYCybFE2wVt7tc3ttbxaLsRMJDrG7tr5TyH+70oUs/2/r9/C8i5+EI9LsHuP92ESkQkYLqat+YV92zBMCHxTUAxEa4F8iu7jwvh7bOLh79wL2+9IYWO+v2VXPp7FQsFiHEanFrULTqeBs3/XUzf3i7cNg1KqV8nzuBXgFMdPk+A6jsu5OInAasAlYYY2r7eyJjzMPGmHxjTL7NZhtJvWMuPjKEtNgwPnK20OP7aaEPZUpyNJfMSuXxjw7S0DL0Ze3e2n0Ee5fhsjlpAAQHWbC70UI/3uaYUvn0pnK9fJ5SAcidQN8M5IpIloiEANcCq113EJFJwIvADcaYIs+X6V15aTEcqnfMXokLH34LHeDO86bQ1N7JXz86MOS+r+44zMSE
cOZkOC6FF2y1uNXl0rPUb1N7J09t1LVklAo0Qwa6MaYTuAt4E9gDPGeM2SUid4jIHc7d/gtIBP4sIttEpGDUKvaCnm4X6L8P3a3nSIvhwhkpPPbhfhrbBm491zV38GFxDZfOTkPE0dsVYrXQ0Tn0LJeeOevRoUE89uF+XctdqQDj1jx0Y8waY8xUY0yOMeaXzm0rjTErnbdvNcbEG2PmOr/yR7PosdYzMAojD3SAb50/hYZWO3/bMHDr+Y1dR+jsNlx2WmrvtpAgN1vodkeAf+2sTKob23n5k0MjrlUp5Xv01H83zExzdH0EWYSo0JGfXDtnYhxLp9pYtb6Ulo7+lxB4dUclmYkRvcsOAARbxa1B0TZni/yCGSnMTIvhoXWlw15YTCnluzTQ3ZARH050aBBxEcG93SAj9a3zp1Db3MHTm8pPuq+6sZ2PS2q57LS0E14nxM1B0Z4ul7BgC18/J4fS6mbe2XP0lOr97Rt7+eVru4feUSnldRrobhARZqTFEDeCKYt9nZ6ZwKLsBB5aW0Kb/cQ+7jd2HaHbwGVzUk/Y7u6gaM/zhQZZuWTWBDLiw3lohMsO9Hjz0yM8+sH+U141Uik1+jTQ3fTDS2bw08tneuS5vnV+LlWN7TxfcGIr/dXtlUxJjmJaSvQJ24OtFrfOFHVtoQdZLdy2JJstB+soOHDspH0fXlfC4x8dGLTlb4zhUH0r3Qa359ArNRy1Te10abegx2igu2nOxDjOzk3yyHMtzklk/qQ4/vJ+SW9QHz3exqYDx7jstNSTunVC3exy6ZnVEhbkOEv06vwM4iOCWbn2xDA2xnDvm4X8ZPUuPvfHdby56wjGnPyf6lhzB+2d3USEWHmuoEKvmqQ8am1RNWf++l88+fEBb5fiNzTQvUBE+OriTCob2ig62gjAmp2HMQYuOy3tpP3d73LpaaE7Aj0iJIgbz8zknT1HKa5q7N2vuaMLe5dh2cwJWC3C15/cwpce3sCOivoTnq+yvg2Ab5yTQ6u9iyc3HBzRzzsSxhg2lNbSNMT688o3fVxSy+1PFNDR1c3h423eLsdvaKB7ycSECMAxEAqOk4mmT4hmSnLUSfs6Zrm4Mw+9pw/9s7f1xjMnExZsOWEJ356W9vkzknnj20v4+RWzKKlq4vMPfMh3nvmk9ySqnn/Pm57MudNsPP7RgZP6/UdLwcE6rn14A2f+6l1+9spuDtY2j8nrqtG35eAxbnl8M5MSIogKDRryojHKfRroXmKLCgUcgV5Z38qWg3VcPufk1jlASJDVvRZ6ZxchQRYsLhfcSIwK5Zr8ibz0ySGOOltCdS2OQI+PCCHIauGGRZN5//vn8s1zc1jz6RHO//37/O6Nvb2fHtLiwrl9aTa1zR28sLXilH7u4212XtxaMeR/4pIqxwW2z8hO4ImPD3Du79/n1sc382FxTb/dQ8o37Kio52uPbSY5OpS/33oGcRHBNLfrCXCeooHuJbZoZ6A3tbNmp2Mt9Utnp/a7b7BV3BoUbbd3ExZ08lt669nZdHUbHvtwP+DoGwdIiPzsJKnosGD+Y9l03vv3c1k+awJ/fr+EP7xdRFiwhfiIYM7MTmR2eiyr1u8/pUGsVetK+d5z21nyu/d4eF0JrR39/2cuO9ZCkEV46IZ8PrznfO46bwqflNVz/aqNXHz/Op7aWDbgY9X4tOfwcW54dBOxEcE8ddsikmPCiAoN0m41D9JA95KwYCvRYUFUN7bzyo7DzEqPITMpst993R0UbbN39fafu5qUGMEls1N5akMZx9vs1DsXCOtvGmZ6XDj3XzuPf955FotzEjlvWjIigohw+9Js9tc08/buk+e2bz5wjH9/fjv3vVXIM5vKWL+vmtLqppO6aN7ZU8W0lGhmpsXwqzV7WXrve/zfh/tP2q+8rpX0+HCsFiElJoz/97lpfHjP+dz7xdMIslj4wUs7WfTrd/n163uob9HB2vGuuKqRr6zaSHiwladvW9R7ScZI7XLxKL2mqBfZ
okPZWlbHjooG7lk+fcD93F+cq7vfQAf4+tIcXt1xmKc3lhHibMUnDDKvfs7EOJ66bdEJ25Y757Y/vK6EZbMmnHDfs5vLeWFrBQL0bcAnRYWSHh9OWmwYuw8f557l07njnBw2HzjGfW8V8t+v7ObhdaXced4UrsmfSEiQhbJjLUxyjjP0CAu2cnX+RL64IIPNB+r460f7eWRdKYfr2/jTdfOGPD7KOw7UNPPlRzYiIjx12xm940fgCHRdGdRzNNC9yBYVysb9jjniA3W3gCPQ3Tr13951woCoq9kZsZw1JZHHPtzPF+ZlIAIx4cNblybIauHWs7P46Su7KThwjPzMhN776ls6mD4hhtV3ncWRhjYO1bdSWd/KobpWDtU7vgqPNJISE9r7s56emcAzt5/JR8U13Pd2ET96+VNWri3h387PpfxYCxfPnNBvHSLCwqwEFmYl8JN/fsrTm8upb+nwyIlfyrMq6lq4ftVG7F3dPHP7mWTbThz0jwq1UukcfFenTgPdi5Kc/ehzJ8ad0Grpy93FuQbqcunx9aU53PjYJp4vKCcuPBirZfjLGFxz+kTuf3cfD60r7RPoduLCgwm2WpiYEDHoz9PX4ilJnJmTyPtF1fzx7SL+44UdACe10PvzpdMn8fjHB/nZq7v53VWnEWTVXsTx4khDG9ev2sjxNjtP37aIaROiT9onMkS7XDxJf/u9qGemi+vKiv0Jtlqwd5khZ3e02bsJCx74LV2Sm8SM1BhqmzuIH2FrNiIkiBsWTeadPUcpqW7q3V7X0kF85MhXohQRzpuWzD/vPIuHb1jAhTNSuHBG8pCPy0uL4dsX5PLi1kN8/cktOlA6TjS02rl+1QZqGtt5/OaFzEqP7Xe/SB0U9SgNdC+amBCB1SJcOkSgh1gdLWl71xCB3jl4C11EuOMcxzVPT2UZ4BvPzCTYamHV+v292xpa7cSO8OIffWv83MwJrPpqPrkpJ7fo+vPdi6by8ytm8a/CKr68asMpn9F6sLaZtUXV+sfhFLxfWEVJdTP/++V5zJ8UP+B+PfPQdSqqZ2iXixddf8YkluQmkRobPuh+PYOYHV3dvbf702bvJjFy8ItDXzo7lfveKiI93v0ukb5s0aFcNT+DF7ZW8L2LppIUFUJ9i31El+fzlBsWTcYWFcK/PbONq1Z+xBM3LyRjhD/jj/+5i3VF1YQEWTgjK4Fzpto4Z6qNKclRp7za5lhqaLHznWc/obqpnbOn2Fiam8SCzHhCg0b/AuI9J8wNFubgaKF3G8fvrj9c2Ly1o8urP4cGuheFBVuZ6kYrNNjZL2zv7IbQgfdrt3cN2uUCjoHNF7+5uPc5R+q2JVk8s7mMJz4+wO1Ls+nsNqfU6veEZbNSefLmEG59ooCr/vIRj9+8kOkTYoZ+YB8Vx1qYPymO+ZPiWVtUzS9e28MvXttDWmwY50xzhPviKUnEhA398xpjqG3uIClqkDfOQw43tLL1YD1by+r4pKyOTw8dB2DOxFhWrS9l5doSwoItnJGVyJLcJJZOtZHr8kfKGOOxP1g1TR0EW4XYIQbeo0Id4dfU3unzgb6rsoHL/vcDLp2dyg8vnTFkQ200aKD7gN5AH2JgdLBpi648ES7ZtigumpHCkxsO9p7hOh5mmZyRncjzd5zJ1x7bzNUrP+aRG/NZlJ3o9uONMRxuaOP86cn86LI8foRjCYR1RdWsLazm1e2HeXpTOVaLcOGMZP73uvn9fmrq7Orm9U+P8Mj6UnZUNPDCN85kweSEk19wmOxd3byyvRJbdCiRoUFsPVjHJ2WOED/c4DgTOCTIwmnpsXztrEwunZ3KnIlxNLV3srG0lvX7ali/z/FHitf2kBITypJcG/trmtl7+Dj5mQkszklkcU4SeWkxIxo4B6hpaicxMnTIPxCRzgvGNLd39p5s56tKq5sxxrEu07/2VvFvF+Ry81lZg36q9jQNdB/Q8wvRPsTUxTY3Wuie9PVzsnlr99HedWLihjkNcrRM
nxDDC99czI2PbuTGRzdx/7VzuWSQaaGuGlrttNq7mBAb1rstPS6c6xZO4rqFk7B3dfNJWT3/2FLOcwUV7DxUf0JQN7d38lxBOY9+sJ+KulaykiIRgQ/21Z5yoK8tquZnr+yipPrEdW3S48JZMDme+ZPimT85nrzUmJNCJCo0iAtmpHDBjBTA8Ufqg33VrNtXwzt7jtLdbbj0tFS2ltXz69f3AhATFsQZ2YkszknkzJxEpiZHn7CsxGBqmtpJih76D3xPoI/1wGhDq53DDa0j+gQ3kHrnfPrn71jMyrUl/Ob1vTxXUM5/f34mS3JtHnudwWig+4AQN1ronV3dtNq7xqR/tMeCyQksmBzPi871XeIjvd9C75EeF84/7ljMLY9v5s6ntvKTy/I4UNtCtzFcPieNBZPi+w2nnlbuQB+Xg60WFmYlkJUUyXMFFWw5WMeCyQlUHW/j8Y8P8LcNZTS02smfHM9/XZbHhTNSuORP69lSVjfin+VATTO/eG037+ypYnJiBA/dsICmtk4iQqzMnxxPSkzY0E/SR3pcOF86fRJfOn0SXd2OGVQ9Uz6rjrfxcWktH5fU8lFJbe+ZwYmRIZwzzcZvrjxtyFZnTVO7W58Eo1xa6GPBGMOanUf4yepdHGtuZ/VdZw84A2e4GpxnLM9Kj+GRG/N5r7CK/169ixse3cTyWRP40WV5pMeNbjeMBroP6Oly6ejqxhhDVWM7ew4fp/BII4VHGtl7pJHiqiY6urqH7LP0tNuXZvP1J7cA46eF3iM+MoS/37qIbz39CT99xXEZvRCrhSc+PkhabBiXz0nj8jlpzEyL6e0aONIT6HGDh6QtOpTMxAje3n2U4qomXv6kEnt3N8tmTuDWJdksmPzZYOD8yfG8sq2Srm4zrC6MpvZOHnyvmEfX7yfYKty9bDo3n53p8T/ajpo+qys5JowVc9NZMTcdcJwc9HFJLW/uOsKLWw/x5YWTTjgHoT81jR1utX57u1wGuMauJx093saPX/6Ut3YfZVZ6DMYYbvrrZvInx5Nji2JKchQ5tiiybZG9dQ1HfYudiBBr7/tz3rRkzvxOIqvWl/LAe8W8X1jNXedP4dYlWaPW8NJA9wE9raHvPrudww2tvWuxAKTEhDJ9QgxLcpOYNiGazw1wduVouWhGCtlJkZTWNBPr5UHR/oSHWFn5lfn8/q0iIkOs3HR2Fu/sPsrq7ZU8+sF+HlpXSnZSJJfPSePzc9OobHCctZgaO3Srd/7keF7ceoidhxr40ukTueXsrH7X41kyJYmnNpZx9cqP+MUVs8lLGzzoursNL287xG9e30tVYztXzk/nnmXTSR5BS9wTMuIjuDo/gqVTbbyz5112HmoYNNAdA8HutdB7GiCv7TjC4pykQceA1hU51gfqmeDYM9Pxs+8dt+ZOjDuhPmMMz24u55dr9tDR2c1/Lp/OLWdnseVgHY+s30/hkUbe2n30hEXn0mLDyHEGfF5aDF+cnzFkd1N9q/2kRk1YsJW7zs/linnp/PK1Pdz7ZiH/2FLBz1aMTjeMBroPmJIcRXpcOGHBFpbPmsD0CTFMmxDN9AnRXh+ItFiE7188jZXrSgddG8abgqyWE9bKuWJeOlfMS6euuYPXPz3CK9sr+dO/9vE/7+4jOjQIi3x20tdg7jpvCrPTY1kxN52EQbqbls2awH1Xz+FXa/Zw+QMf8LXFmXz3oqm93Q2utpfX89NXdvFJWT1zJsbx0A0LmDfE1L+xkhIThi06lJ2HGgbdr6HVjr3LkBQ19O9DZmIEN5+VxWMf7mdXZQP3Xzu335Z9R2c3tzy+echzMcAx6L/pBxdgsQgHapr5zxd38nFpLYuyE/jNlaf1/tE9IzuRM5wD5h2d3ZQda6a4qomS6mZKqpooqW7i+YJymju6yIgLZ/GUwa9YVt9iJ3aA/wMZ8RH85SsLWFdUzU9X72LP4eOjEujirQn9+fn5pqCgwCuvrVRfR4+38eqOw6zeXklCRDD/d9NC
j79GfUsHv3uzkKc3lZEcHcpPLp/J8lkTert7Sqqb+Nwf15EQGcLdy6Zz5bx0twchx8pN/7eJ0ppmvn/xNBpa7TS02jne2un4t83O8VY7NU0d7Dl8nP+5dm5vt81Q3ius4vvP7+B4m527l03npsWZJ/zsJdVNXHDfWn5+xSwudzkRT3q6ipz/vLbjMD94aSf/vPMsNu6v5Q9vFxFssfCDS2fwpfyJwz6eLR2dzP3Z29ywaDI/vixv0H2/+JePCLZaePr2RYPu19HZjQgjnjosIluMMfn93actdKVwtD5vOTuLW87OGrXXiIsI4VdfmM3VCzL44Uuf8s2/b+WcqTZ+tmImkxMjWVtYTVe34cVvLB7WWjhjacHkeN4rrOaupz7p3dYz3zwmPJiYsGCSo0PJS83g7CFatK7Om5bMm99Zwt0v7OTnr+7m/cIqfn/1nN4B3wM1jpk9eakxg34qvTAvmR+8BF9ZtZHG9k4unJHCL66YdcKspeGICAliUXYi7+45yrcvzB303IP6Vju5/VxxrK/RnMaoga7UGJs3KZ7Vd53FkxsOct9bRVz0x3Xcee4UtpXXMWmYC5uNtduWZrN4ShJRoUGOEA8LJizY4pETkhKjQnnkxgU8tamMn7+6m4vvX8dvrpzNslmp7HcGevYA1wzokRwdRv7keA7UNvPrq+Zx6eyTL7o+XF+Yl8Z3n93Owl++wyWzUrnm9ImckZVw0vPWt9i93gWqga6UFwRZLdx0VhaXzE7l56/u5o/vFAFwTX6GlysbXGiQdcjT+U+FiHD9GZNZlJ3Id5/dxh1/28o1+RnYuwyx4cFuTY39680LsYp47MzTL8zLIDspimcLynllWyUvfnKIzMQIrs6fyFXzM5gQG4YxhobWDq+fLa2BrpQXpcSE8cCX53NNfjUPvFfM1fkTvV3SuJBji+KFbyzm/neK+PP7JRjjuOiKO/obbD5VcybGMWdiHD++NI/XPz3Ms5vLuffNQu57q5BzpyWzYm4a9i7j9am7GuhKjQNLp9pYOnVszib0FcFWC9+/eDrnTE3m7hd2sDjH/SUcRkt4iJUr52dw5fwMDtQ081xBOf/YUsG/9lYBjPl5IH3pLBellDoFnV3drNvnWOvnzvOmjPr5AjrLRSmlRkmQ1cL501M4f3qKt0vRC1wopZS/0EBXSik/oYGulFJ+QgNdKaX8hFuBLiLLRKRQRIpF5J5+7hcR+ZPz/h0iMt/zpSqllBrMkIEuIlbgQWA5kAdcJyJ9V6lZDuQ6v24H/uLhOpVSSg3BnRb6QqDYGFNqjOkAngFW9NlnBfCEcdgAxImIe9f8Ukop5RHuBHo6UO7yfYVz23D3UUopNYrcObGov6XK+p5e6s4+iMjtOLpkAJpEpNCN1+9PElAzwseOlfFe43ivD8Z/jVrfqRvvNY7H+iYPdIc7gV4BuK4YlAFUjmAfjDEPAw+78ZqDEpGCgU59HS/Ge43jvT4Y/zVqfaduvNc43uvry50ul81ArohkiUgIcC2wus8+q4EbnbNdFgENxpjDHq5VKaXUIIZsoRtjOkXkLuBNwAo8ZozZJSJ3OO9fCawBLgGKgRbgptErWSmlVH/cWpzLGLMGR2i7blvpctsAd3q2tEGdcrfNGBjvNY73+mD816j1nbrxXuN4r+8EXls+VymllGfpqf9KKeUnNNCVUspP+FygD7WuzCi+7kQReU9E9ojILhH5tnP7T0XkkIhsc35d4vKY/3TWWSgiF7tsXyAiO533/Uk8ccl0x/MecD7vNhEpcG5LEJG3RWSf8994l/3Hur5pLsdpm4gcF5HvePMYishjIlIlIp+6bPPYMRORUBF51rl9o4hkeqC+e0Vkr3PdpJdEJM65PVNEWl2O40qXx4xKfYPU6LH3dJSO4bMutR0QkW3ePIYeY4zxmS8cs2xKgGwgBNgO5I3Ra6cC8523o4EiHGvb/BT49372z3PWFwpkOeu2Ou/bBJyJ44Ss14HlHqrxAJDUZ9vv
gHuct+8Bfuut+vp5L4/gOEnCa8cQWArMBz4djWMGfBNY6bx9LfCsB+r7HBDkvP1bl/oyXffr8zyjUt8gNXrsPR2NY9jn/vuA//LmMfTUl6+10N1ZV2ZUGGMOG2O2Om83AnsYfHmDFcAzxph2Y8x+HFM6F4pjjZsYY8zHxvEb8ARwxSiWvgJ43Hn7cZfX8nZ9FwAlxpiDQ9Q+qjUaY9YBx/p5XU8dM9fn+gdwwXA+TfRXnzHmLWNMp/PbDThO5BvQaNY3UI2DGBfHsIfzea4Bnh7sOUb7GHqKrwX6uFgzxvmRah6w0bnpLufH38dcPp4PVGu683bf7Z5ggLdEZIs4llkASDHOk7yc/yZ7sT5X13Lif6LxcgzBs8es9zHOEG4APHn5+ptxtBZ7ZInIJyKyVkSWuNTgjfo89Z6OZo1LgKPGmH0u28bTMRwWXwt0t9aMGdUCRKKAF4DvGGOO41gqOAeYCxzG8fENBq51NH+Gs4wx83EsZ3yniCwdZF9v1Od4YccZx58HnnduGk/HcDAjqWfUahWRHwKdwN+dmw4Dk4wx84DvAU+JSIyX6vPkezqa7/d1nNiwGE/HcNh8LdDdWjNmtIhIMI4w/7sx5kUAY8xRY0yXMaYbeARHt9BgtVZw4kdkj/0MxphK579VwEvOWo46Py72fGys8lZ9LpYDW40xR531jptj6OTJY9b7GBEJAmJxv3tiQCLyVeAy4HpnFwDOboxa5+0tOPqnp3qjPg+/p6N1DIOAK4FnXeoeN8dwJHwt0N1ZV2ZUOPvEHgX2GGP+4LLddd33LwA9I+mrgWudI+BZOC7+scn5Eb5RRBY5n/NG4J8eqC9SRKJ7buMYOPvUWcdXnbt91eW1xrS+Pk5oFY2XY+jCk8fM9bm+CPyrJ4BHSkSWAXcDnzfGtLhst4njgjSISLazvtKxrs/5+p58T0elRuBCYK8xprcrZTwdwxHx1mjsSL9wrBlThOMv5w/H8HXPxvExagewzfl1CfAksNO5fTWQ6vKYHzrrLMRlFgaQj+MXvAR4AOcZu6dYXzaO2QPbgV09xwZHX967wD7nvwneqM/luSOAWiDWZZvXjiGOPyyHATuOltYtnjxmQBiOrqViHLMksj1QXzGOPtue38OeGRZXOd/77cBW4PLRrm+QGj32no7GMXRu/ytwR599vXIMPfWlp/4rpZSf8LUuF6WUUgPQQFdKKT+hga6UUn5CA10ppfyEBrpSSvkJDXSllPITGuhKKeUn/j+M3NG8Ajq4agAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 4. Training\n",
    "criterion = ContrastiveLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# x-axis (pairs processed) and y-axis (loss) of the loss curve\n",
    "counter = []\n",
    "loss_history = []\n",
    "\n",
    "def train(epoch, dataloader, model, loss_fn, optimizer):\n",
    "    \"\"\"Run one training epoch over `dataloader`.\n",
    "\n",
    "    Logs the loss every 10 batches and on the last batch, and appends those\n",
    "    points to the module-level `counter` / `loss_history` lists for plotting.\n",
    "    \"\"\"\n",
    "    model.train()  # training mode\n",
    "    dataset_cnt = len(dataloader.dataset)\n",
    "    batch_cnt = len(dataloader)\n",
    "\n",
    "    for batch, (label, x1, x2) in enumerate(dataloader):\n",
    "        output1, output2 = model(x1, x2)\n",
    "        # use the loss function passed in, not the global `criterion`\n",
    "        loss_contrastive = loss_fn(output1, output2, label)\n",
    "        optimizer.zero_grad()\n",
    "        loss_contrastive.backward()\n",
    "        optimizer.step()\n",
    "        if batch % 10 == 0 or batch == batch_cnt - 1:\n",
    "            loss_value = loss_contrastive.item()\n",
    "            print(\"Epoch:{} batch:{} loss:{} \".format(epoch, batch, loss_value))\n",
    "            # number of pairs processed before this batch\n",
    "            counter.append(epoch * dataset_cnt + dataloader.batch_size * batch)\n",
    "            loss_history.append(loss_value)\n",
    "\n",
    "epochs = 30\n",
    "for epoch in range(epochs):\n",
    "    train(epoch, train_dataloader, model, criterion, optimizer)\n",
    "\n",
    "# NOTE: this pickles the whole nn.Module, so SiameseModel must be defined when\n",
    "# loading; torch.save(model.state_dict(), ...) would be the more portable choice.\n",
    "torch.save(model, 'siamese_module.pth')\n",
    "\n",
    "def show_plot(iteration, loss):\n",
    "    \"\"\"Plot the recorded training loss against the number of pairs processed.\"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(iteration, loss)\n",
    "    ax.set(title=\"Contrastive loss during training\", xlabel=\"training pairs processed\", ylabel=\"loss\")\n",
    "    plt.show()\n",
    "\n",
    "show_plot(counter, loss_history)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6f66d893",
   "metadata": {},
   "outputs": [],
   "source": [
    "class IrisPredictDataset(Dataset):\n",
    "    \"\"\"Single iris samples for embedding and evaluation.\n",
    "\n",
    "    Item idx is (float32 numpy array of the first 4 columns,\n",
    "    integer class id looked up in self.labels from column 4).\n",
    "    \"\"\"\n",
    "    def __init__(self,data_type=\"train\"):\n",
    "        assert data_type in ('train','test')\n",
    "        self.labels = {'Iris-setosa':0,'Iris-versicolor':1,'Iris-virginica':2}\n",
    "        self.pd_frame = pd.read_csv(\"./dataset/iris/%s.csv\" % (data_type),header=None)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.pd_frame.shape[0]\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.pd_frame.iloc[idx]\n",
    "        features = row.iloc[0:4].to_numpy(np.float32)\n",
    "        class_id = self.labels[row.iloc[4]]\n",
    "        return features, class_id\n",
    "\n",
    "# DataLoaders over the whole train / test splits (no shuffling: CSV row order)\n",
    "train1_dataloader, test1_dataloader = (\n",
    "    DataLoader(IrisPredictDataset(split), batch_size=64) for split in (\"train\", \"test\")\n",
    ")\n",
    "# peek at the first batch of each split\n",
    "(train_x, train_label), (test_x, test_label) = next(iter(train1_dataloader)), next(iter(test1_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "de8b2e43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([30, 1]) torch.Size([30, 1])\n",
      "tensor(96.6667)\n",
      "tensor([ True,  True,  True,  True,  True,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True])\n"
     ]
    }
   ],
   "source": [
    "# Reload the trained model from disk. NOTE: torch.load unpickles arbitrary objects,\n",
    "# so only load checkpoints you produced yourself.\n",
    "# PyTorch >= 2.6 defaults to weights_only=True, which refuses a pickled nn.Module like\n",
    "# this one; opt out explicitly when torch.load accepts the keyword (added in 1.13).\n",
    "_load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}\n",
    "model_x = torch.load('siamese_module.pth', **_load_kwargs)\n",
    "\n",
    "# Embed every sample yielded by a dataloader\n",
    "def compute_vector(model, dataloader, raw_features=False):\n",
    "    \"\"\"Return (labels, vectors) for every sample yielded by `dataloader`.\n",
    "\n",
    "    Args:\n",
    "        model: object with a `predict(x)` method (e.g. SiameseModel); it is put\n",
    "            into eval mode. Ignored when raw_features is True.\n",
    "        dataloader: yields (features, labels) batches, labels of shape (batch,).\n",
    "        raw_features: if True, use the input features themselves as the vectors.\n",
    "            Useful baseline: on iris, raw features already reach about 93%\n",
    "            nearest-neighbour accuracy; Siamese networks pay off more on\n",
    "            images and text.\n",
    "\n",
    "    Returns:\n",
    "        labels: tensor of shape (N, 1); vectors: tensor of shape (N, dim).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the dataloader yields no batches.\n",
    "    \"\"\"\n",
    "    if not raw_features:\n",
    "        model.eval()\n",
    "    label_chunks, vector_chunks = [], []\n",
    "    # inference only: build no autograd graph, so the vectors carry no grad_fn\n",
    "    with torch.no_grad():\n",
    "        for x, label in dataloader:\n",
    "            vector_chunks.append(x if raw_features else model.predict(x))\n",
    "            label_chunks.append(label.reshape(-1, 1))\n",
    "    if not vector_chunks:\n",
    "        raise ValueError(\"dataloader yielded no batches\")\n",
    "    return torch.cat(label_chunks, 0), torch.cat(vector_chunks, 0)\n",
    "\n",
    "train_label, train_vectors = compute_vector(model_x, train1_dataloader)\n",
    "test_label, test_vectors = compute_vector(model_x, test1_dataloader)\n",
    "\n",
    "# 1-nearest-neighbour classification in embedding space: each test sample gets\n",
    "# the class of the closest training sample (Euclidean distance).\n",
    "# https://pytorch.org/docs/stable/generated/torch.cdist.html\n",
    "matrix_dist = torch.cdist(test_vectors, train_vectors, p=2)  # (n_test, n_train)\n",
    "min_idxs = matrix_dist.argmin(dim=1)  # use topk(largest=False) for the k closest\n",
    "most_close_sample = train_label.index_select(0, min_idxs)\n",
    "assert most_close_sample.shape == test_label.shape\n",
    "\n",
    "is_correct = most_close_sample[:, 0].eq(test_label[:, 0])\n",
    "# plain accuracy (not AUC): share of test samples whose nearest neighbour has the same class\n",
    "accuracy = is_correct.float().mean().item() * 100\n",
    "print(f\"accuracy: {accuracy:.2f}%\")\n",
    "print(is_correct)\n",
    "\n",
    "# Accuracy observed for different training-set sizes:\n",
    "# train:120 acc=96.67%\n",
    "# train:12 acc=83.33%\n",
    "# train:24 acc=86.67 ~ 93.33%\n",
    "# train:36 acc=93.33 ~ 96.67%  (classification model: 60.0%)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13fa8e64",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
